#!/usr/bin/env python3
#
# train an SSD detection model on Pascal VOC or Open Images datasets
# https://github.com/dusty-nv/jetson-inference/blob/master/docs/pytorch-ssd.md
#
import os
import sys
import logging
import datetime
import torch
import torch_npu
from torch_npu.npu import amp

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import CosineAnnealingLR

from vision.utils.misc import Timer
from vision.ssd.ssd import MatchPrior
from vision.ssd.mobilenetv1_ssd import create_mobilenetv1_ssd
from vision.dataset import VOCDataset
from vision.nn.multibox_loss import MultiboxLoss
from vision.ssd.config import mobilenetv1_ssd_config
from vision.ssd.data_preprocessing import TrainAugmentation, TestTransform


# Path to the pretrained base SSD checkpoint; downloaded on demand in __main__.
DEFAULT_PRETRAINED_MODEL = 'models/mobilenet-v1-ssd-mp-0_675.pth'

# Log to stdout with timestamps. The previous getattr(logging, "INFO", logging.INFO)
# indirection always evaluated to logging.INFO, so use it directly.
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format='%(asctime)s - %(message)s', datefmt="%Y-%m-%d %H:%M:%S")

# make sure that the checkpoint output dir exists
checkpoint_folder = os.path.expanduser("models")
# makedirs(exist_ok=True) avoids the exists()/mkdir() race and creates parent dirs,
# unlike the former os.path.exists + os.mkdir pair
os.makedirs(checkpoint_folder, exist_ok=True)

# one timestamped TensorBoard run directory per launch
tensorboard = SummaryWriter(log_dir=os.path.join(
    checkpoint_folder, "tensorboard",
    datetime.datetime.now().strftime('%Y%m%d_%H%M%S')))

# all training/eval tensors are moved to the first Ascend NPU
DEVICE = torch.device("npu:0")

def train(loader, net, criterion, optimizer, device, scaler, debug_steps=100, epoch=-1):
    """Run one training epoch with AMP mixed precision and log loss statistics.

    Args:
        loader: DataLoader yielding (images, boxes, labels) batches.
        net: SSD model; put into train mode here.
        criterion: loss returning (regression_loss, classification_loss).
        optimizer: optimizer stepped through the AMP GradScaler.
        device: device batches are moved to before the forward pass.
        scaler: amp.GradScaler handling loss scaling for mixed precision.
        debug_steps: emit a running-average log line every this many steps.
        epoch: epoch index, used only for logging and TensorBoard tags.
    """
    net.train(True)

    # accumulated over the whole epoch
    train_loss = 0.0
    train_regression_loss = 0.0
    train_classification_loss = 0.0

    # accumulated between debug log lines, reset after each report
    running_loss = 0.0
    running_regression_loss = 0.0
    running_classification_loss = 0.0

    num_batches = 0

    for i, data in enumerate(loader):
        images, boxes, labels = data
        images = images.to(device)
        boxes = boxes.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        # forward pass under autocast so matmuls/convs run in reduced precision
        with amp.autocast():
            confidence, locations = net(images)
            regression_loss, classification_loss = criterion(confidence, locations, labels, boxes)
            loss = regression_loss + classification_loss
        # scaled backward + optimizer step; update() adjusts the scale factor
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()
        train_regression_loss += regression_loss.item()
        train_classification_loss += classification_loss.item()

        running_loss += loss.item()
        running_regression_loss += regression_loss.item()
        running_classification_loss += classification_loss.item()

        if i and i % debug_steps == 0:
            avg_loss = running_loss / debug_steps
            avg_reg_loss = running_regression_loss / debug_steps
            avg_clf_loss = running_classification_loss / debug_steps
            logging.info(
                f"Epoch: {epoch}, Step: {i}/{len(loader)}, " +
                f"Avg Loss: {avg_loss:.4f}, " +
                f"Avg Regression Loss {avg_reg_loss:.4f}, " +
                f"Avg Classification Loss: {avg_clf_loss:.4f}"
            )
            running_loss = 0.0
            running_regression_loss = 0.0
            running_classification_loss = 0.0

        num_batches += 1

    # Guard against an empty loader: the unguarded division previously raised
    # ZeroDivisionError when the dataset produced no batches.
    if num_batches == 0:
        logging.warning(f"Epoch: {epoch}, training loader produced no batches; skipping stats.")
        return

    train_loss /= num_batches
    train_regression_loss /= num_batches
    train_classification_loss /= num_batches

    logging.info(
        f"Epoch: {epoch}, " +
        f"Training Loss: {train_loss:.4f}, " +
        f"Training Regression Loss {train_regression_loss:.4f}, " +
        f"Training Classification Loss: {train_classification_loss:.4f}"
    )

    tensorboard.add_scalar('Loss/train', train_loss, epoch)
    tensorboard.add_scalar('Regression Loss/train', train_regression_loss, epoch)
    tensorboard.add_scalar('Classification Loss/train', train_classification_loss, epoch)

def test(loader, net, criterion, device):
    """Evaluate the network over `loader` and return mean losses.

    Args:
        loader: DataLoader yielding (images, boxes, labels) batches.
        net: SSD model; put into eval mode here.
        criterion: loss returning (regression_loss, classification_loss).
        device: device batches are moved to before the forward pass.

    Returns:
        (avg_total_loss, avg_regression_loss, avg_classification_loss),
        all 0.0 when the loader yields no batches.
    """
    net.eval()
    running_loss = 0.0
    running_regression_loss = 0.0
    running_classification_loss = 0.0
    num = 0
    for data in loader:
        images, boxes, labels = data
        images = images.to(device)
        boxes = boxes.to(device)
        labels = labels.to(device)
        num += 1
        # no grad needed for evaluation; autocast matches the training precision
        with torch.no_grad(), amp.autocast():
            confidence, locations = net(images)
            regression_loss, classification_loss = criterion(confidence, locations, labels, boxes)
            loss = regression_loss + classification_loss

        running_loss += loss.item()
        running_regression_loss += regression_loss.item()
        running_classification_loss += classification_loss.item()

    # Empty validation loader previously crashed with ZeroDivisionError.
    if num == 0:
        logging.warning("Validation loader produced no batches; returning zero losses.")
        return 0.0, 0.0, 0.0

    return running_loss / num, running_regression_loss / num, running_classification_loss / num

if __name__ == '__main__':

    timer = Timer()
    create_net = create_mobilenetv1_ssd
    config = mobilenetv1_ssd_config
    config.set_image_size(300)

    # training hyper-parameters, grouped up front so derived values (scheduler
    # T_max, epoch loop) cannot drift out of sync with each other
    dataset_path = "dataset"
    batch_size = 4
    num_workers = 3
    num_epochs = 100

    # create data transforms for train/test/val
    train_transform = TrainAugmentation(config.image_size, config.image_mean, config.image_std)
    target_transform = MatchPrior(config.priors, config.center_variance,
                                  config.size_variance, 0.5)
    test_transform = TestTransform(config.image_size, config.image_mean, config.image_std)

    # load the training dataset
    logging.info("Prepare training datasets.")
    train_dataset = VOCDataset(dataset_path, transform=train_transform,
                               target_transform=target_transform)
    num_classes = len(train_dataset.class_names)
    logging.info("Train dataset size: {}".format(len(train_dataset)))
    train_loader = DataLoader(train_dataset, batch_size,
                              num_workers=num_workers,
                              shuffle=True)

    # create validation dataset (no shuffling needed for evaluation)
    val_dataset = VOCDataset(dataset_path, transform=test_transform,
                             target_transform=target_transform, is_test=True)
    val_loader = DataLoader(val_dataset, batch_size,
                            num_workers=num_workers,
                            shuffle=False)

    # create the network
    logging.info("Build network.")
    net = create_net(num_classes)
    last_epoch = -1

    # load pretrained base SSD weights, downloading them on first use
    timer.start("Load Model")
    logging.info(f"Init from pretrained SSD {DEFAULT_PRETRAINED_MODEL}")
    if not os.path.exists(DEFAULT_PRETRAINED_MODEL):
        # NOTE(review): shell download of a fixed URL with --no-check-certificate;
        # consider urllib with checksum validation for a hardened pipeline
        os.system(f"wget --quiet --show-progress --progress=bar:force:noscroll --no-check-certificate https://nvidia.box.com/shared/static/djf5w54rjvpqocsiztzaandq1m3avr7c.pth -O {DEFAULT_PRETRAINED_MODEL}")
    net.init_from_pretrained_ssd(DEFAULT_PRETRAINED_MODEL)
    logging.info(f'Took {timer.end("Load Model"):.2f} seconds to load the model.')

    # move the model to the NPU
    net.to(DEVICE)

    # define loss function, optimizer and the AMP gradient scaler
    criterion = MultiboxLoss(config.priors, iou_threshold=0.5, neg_pos_ratio=3,
                             center_variance=0.1, size_variance=0.2, device=DEVICE)
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9,
                                weight_decay=5e-4)
    scaler = amp.GradScaler()

    # cosine annealing over the full run: T_max is tied to num_epochs
    # (previously a separate hard-coded 100 that could silently diverge)
    logging.info("Uses CosineAnnealingLR scheduler.")
    scheduler = CosineAnnealingLR(optimizer, num_epochs, last_epoch=last_epoch)

    # train for the desired number of epochs, keeping the best checkpoint by val loss
    logging.info(f"Start training from epoch {last_epoch + 1}.")
    best_loss = float('inf')  # any first validation result becomes the new best
    model_path = os.path.join(checkpoint_folder, "best.pth")
    for epoch in range(last_epoch + 1, num_epochs):
        train(train_loader, net, criterion, optimizer, device=DEVICE, scaler=scaler, debug_steps=10, epoch=epoch)
        scheduler.step()
        val_loss, val_regression_loss, val_classification_loss = test(val_loader, net, criterion, DEVICE)

        logging.info(
            f"Epoch: {epoch}, " +
            f"Validation Loss: {val_loss:.4f}, " +
            f"Validation Regression Loss {val_regression_loss:.4f}, " +
            f"Validation Classification Loss: {val_classification_loss:.4f}"
        )

        tensorboard.add_scalar('Loss/val', val_loss, epoch)
        tensorboard.add_scalar('Regression Loss/val', val_regression_loss, epoch)
        tensorboard.add_scalar('Classification Loss/val', val_classification_loss, epoch)

        # keep only the best checkpoint seen so far
        if val_loss < best_loss:
            best_loss = val_loss
            net.save(model_path)
            logging.info(f"Saved model {model_path}")

    logging.info("Task done, exiting program.")
    tensorboard.close()